{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2d579aba-7439-42a6-aae4-cf1983095dee",
   "metadata": {},
   "source": [
    "# Example 05 - TorchSigWideband with YOLOv8 Detector (Creates and Populates Image/Label Directories)\n",
    "This notebook showcases using the TorchSig Wideband dataset to train a YOLOv8 model.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40a026bd-f096-47f3-a262-48ab5defe23e",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4bff6d5-4b2d-4db2-97a0-f45843e7cc60",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchsig.datasets.datamodules import WidebandDataModule\n",
    "from torch.utils.data import DataLoader\n",
    "from torchsig.utils.dataset import collate_fn\n",
    "from torchsig.datasets.torchsig_narrowband import TorchSigNarrowband\n",
    "from torchsig.datasets.torchsig_wideband import TorchSigWideband\n",
    "from torchsig.datasets.signal_classes import torchsig_signals\n",
    "from torchsig.transforms.target_transforms import DescToListTuple, ListTupleToYOLO\n",
    "from torchsig.transforms.transforms import Spectrogram, SpectrogramImage, Normalize, Compose, Identity\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n",
    "\n",
    "from ultralytics import YOLO\n",
    "import cv2\n",
    "import yaml\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd94880c-2c78-45e6-9c22-302c433fb73a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a46c59d-e20c-435b-8a22-249a56bdd810",
   "metadata": {},
   "source": [
    "## Instantiate Wideband Dataset\n",
    "After generating the Wideband dataset (see `03_example_wideband_dataset.ipynb`), we can instantiate it with the needed transforms. Change `root` to the path of your generated dataset.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "993f59e9-6aa6-4379-8cb7-9b1578eaf7fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Dataset location and which variant of it to load\n",
    "root = './datasets/wideband'\n",
    "impaired = True\n",
    "qa = True\n",
    "\n",
    "# Spectrogram FFT size and number of signal classes\n",
    "fft_size = 512\n",
    "num_classes = len(torchsig_signals.class_list)\n",
    "\n",
    "# Data loading settings\n",
    "batch_size = 1\n",
    "num_workers = 4\n",
    "\n",
    "# Input pipeline: normalize, take the spectrogram, normalize again, convert to an image\n",
    "transform = Compose([\n",
    "    Normalize(norm=np.inf, flatten=True),\n",
    "    Spectrogram(nperseg=fft_size, noverlap=0, nfft=fft_size, detrend=None),\n",
    "    Normalize(norm=np.inf, flatten=True),\n",
    "    SpectrogramImage(),\n",
    "])\n",
    "\n",
    "# Target pipeline: DescToListTuple followed by ListTupleToYOLO\n",
    "target_transform = Compose([DescToListTuple(), ListTupleToYOLO()])\n",
    "\n",
    "# Build the datamodule from the settings above and set it up for fitting\n",
    "datamodule = WidebandDataModule(\n",
    "    root=root,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "wideband_train, wideband_val = datamodule.train, datamodule.val\n",
    "\n",
    "# Pull one randomly chosen validation sample and print a summary of it\n",
    "idx = np.random.randint(len(wideband_val))\n",
    "data, label = wideband_val[idx]\n",
    "\n",
    "print(\"Training Dataset length: {}\".format(len(wideband_train)))\n",
    "print(\"Validation Dataset length: {}\".format(len(wideband_val)))\n",
    "print(\"Data shape: {}\\n\\t\".format(data.shape))\n",
    "print(f\"Label length: {len(label)}\", end=\"\\n\\t\")\n",
    "print(*label, sep=\"\\n\\t\")\n",
    "print(f\"Label: {type(label)} of {type(label[0])} \\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2938efda-7b7f-42e9-8bcf-022ebcaf2d32",
   "metadata": {},
   "source": [
    "## Format Dataset for YOLO\n",
    "Next, the datasets are rewritten to disk in an Ultralytics YOLO-compatible directory layout. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more. \n",
    "\n",
    "Additionally, create a yaml file for dataset configuration. See [Ultralytics: Train Custom Data - Create dataset.yaml](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#21-create-datasetyaml)\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0881b3e7-f145-41fb-a6fd-4b4176cc967b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prepare_data(dataset: TorchSigWideband, output: str, train: bool, impaired: bool) -> None:\n",
    "    \"\"\"Write a wideband dataset to disk as .png images and .txt label files in YOLO structure.\n",
    "\n",
    "    Files go under ``<output>/wideband_yolo/<impaired|clean>/{images,labels}/<train|val>``.\n",
    "    Each sample produces one image and one label file, both named by the\n",
    "    sample index zero-padded to 10 digits.\n",
    "\n",
    "    Args:\n",
    "        dataset: Indexable dataset whose items are ``(image, labels)`` pairs. The first\n",
    "            five fields of every label are written as one space-separated row.\n",
    "        output: Directory under which the ``wideband_yolo`` folder is created.\n",
    "        train: Write into the ``train`` split when True, otherwise ``val``.\n",
    "        impaired: Write under ``impaired`` when True, otherwise ``clean``.\n",
    "\n",
    "    Raises:\n",
    "        IOError: If OpenCV reports that an image could not be written.\n",
    "    \"\"\"\n",
    "    output_root = os.path.join(output, \"wideband_yolo\")\n",
    "    os.makedirs(output_root, exist_ok=True)\n",
    "    # Keep the bool parameters intact; use separate names for the directory components\n",
    "    impairment_dir = \"impaired\" if impaired else \"clean\"\n",
    "    split_dir = \"train\" if train else \"val\"\n",
    "\n",
    "    label_dir = os.path.join(output_root, impairment_dir, \"labels\", split_dir)\n",
    "    image_dir = os.path.join(output_root, impairment_dir, \"images\", split_dir)\n",
    "    os.makedirs(label_dir, exist_ok=True)\n",
    "    os.makedirs(image_dir, exist_ok=True)\n",
    "\n",
    "    for i in tqdm(range(len(dataset))):\n",
    "        image, labels = dataset[i]\n",
    "        filename_base = str(i).zfill(10)\n",
    "        label_filename = os.path.join(label_dir, filename_base + \".txt\")\n",
    "        image_filename = os.path.join(image_dir, filename_base + \".png\")\n",
    "\n",
    "        # One row per label; a sample with no labels yields an empty file\n",
    "        with open(label_filename, \"w\") as f:\n",
    "            f.write(\"\\n\".join(f\"{x[0]} {x[1]} {x[2]} {x[3]} {x[4]}\" for x in labels))\n",
    "\n",
    "        # cv2.imwrite signals failure by returning False rather than raising\n",
    "        if not cv2.imwrite(image_filename, image, [cv2.IMWRITE_PNG_COMPRESSION, 9]):\n",
    "            raise IOError(f\"Failed to write image: {image_filename}\")\n",
    "\n",
    "\n",
    "prepare_data(wideband_train, \"./datasets/wideband\", True, True)\n",
    "prepare_data(wideband_val, \"./datasets/wideband\", False, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "275ab1a3-3e66-45bf-97d9-3323228bbcf3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Write the dataset configuration file that is passed to YOLO training\n",
    "config_name = \"05_yolo.yaml\"\n",
    "\n",
    "# Map class index -> class name\n",
    "classes = dict(enumerate(torchsig_signals.class_list))\n",
    "# NOTE(review): index 0 is renamed to 'signal', replacing the first class name - confirm this is intended\n",
    "classes[0] = 'signal'\n",
    "\n",
    "wideband_yaml_dict = {\n",
    "    \"path\": \"./wideband/wideband_yolo\",\n",
    "    \"train\": \"impaired/images/train\",\n",
    "    \"val\": \"impaired/images/val\",\n",
    "    \"nc\": num_classes,\n",
    "    \"names\": classes,\n",
    "}\n",
    "\n",
    "with open(config_name, 'w+') as yaml_file:\n",
    "    yaml.dump(wideband_yaml_dict, yaml_file, default_flow_style=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23aed0db-8a41-4095-9095-deba80ea94a1",
   "metadata": {},
   "source": [
    "## Instantiate YOLO Model\n",
    "Download desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). We will use YOLOv8, specifically `yolov8x.pt`\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f02e39-bf6f-4a30-a041-575acdfab1af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# -nc (no-clobber): skip the download when yolov8x.pt already exists, so re-running this cell is safe\n",
    "!wget -nc https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8x.pt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf7ad773-5469-4306-ac4d-4f8c30043cda",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Path to the YOLOv8 weights file, then the model built from it\n",
    "modelPath = \"yolov8x.pt\"\n",
    "model = YOLO(modelPath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22fa62fc-01d9-46b7-9cd2-38643f54d1de",
   "metadata": {},
   "source": [
    "## Train\n",
    "Train YOLO. See [Ultralytics Train](https://docs.ultralytics.com/modes/train/#train-settings) for training hyperparameter options.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad7e5f57-3b5b-4074-b87c-6e18c7b973af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train on the dataset described by the YAML config; use GPU 0 when CUDA is available, else the CPU\n",
    "results = model.train(\n",
    "    data=config_name,\n",
    "    epochs=5,\n",
    "    batch=batch_size,\n",
    "    imgsz=640,\n",
    "    device=0 if torch.cuda.is_available() else \"cpu\",\n",
    "    workers=1,\n",
    "    project=\"yolo\",\n",
    "    name=\"05_example\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bbd29a1-fa6a-4154-a93c-fc07349b9388",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Show the results.png plot saved in the training run directory\n",
    "results_path = os.path.join(results.save_dir, \"results.png\")\n",
    "results_img = cv2.imread(results_path)\n",
    "# cv2.imread returns None instead of raising when the file cannot be read\n",
    "if results_img is None:\n",
    "    raise FileNotFoundError(f\"Could not read training results plot: {results_path}\")\n",
    "# OpenCV loads images in BGR channel order; matplotlib expects RGB\n",
    "results_img = cv2.cvtColor(results_img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "plt.figure(figsize=(10, 20))\n",
    "plt.imshow(results_img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ad0c459-2367-4f7d-89ae-5f4cfd8f545f",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "Check the trained model's performance by comparing the ground-truth labels with the model's predictions on a validation batch. From here, you can run the trained model on new images with `model([\"img1.png\", \"img2.png\",...])`\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c6f2681-b25e-4969-832c-aeb8f80b0f21",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Compare the saved label and prediction images for validation batch 2, side by side\n",
    "label_path = os.path.join(results.save_dir, \"val_batch2_labels.jpg\")\n",
    "pred_path = os.path.join(results.save_dir, \"val_batch2_pred.jpg\")\n",
    "\n",
    "# Separate names so the dataset sample's `label` from the earlier cell is not overwritten\n",
    "label_img = cv2.imread(label_path)\n",
    "pred_img = cv2.imread(pred_path)\n",
    "# cv2.imread returns None instead of raising when the file cannot be read\n",
    "for img_path, img in ((label_path, label_img), (pred_path, pred_img)):\n",
    "    if img is None:\n",
    "        raise FileNotFoundError(f\"Could not read validation image: {img_path}\")\n",
    "\n",
    "# OpenCV loads images in BGR channel order; matplotlib expects RGB\n",
    "label_img = cv2.cvtColor(label_img, cv2.COLOR_BGR2RGB)\n",
    "pred_img = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "f, ax = plt.subplots(1, 2, figsize=(15, 9))\n",
    "ax[0].imshow(label_img)\n",
    "ax[0].set_title(\"Label\")\n",
    "ax[1].imshow(pred_img)\n",
    "ax[1].set_title(\"Prediction\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
